---
title: PPO
keywords: fastai
sidebar: home_sidebar
summary: "Proximal Policy Optimization"
description: "Proximal Policy Optimization"
nb_path: "nbs/11a_agents.policy_gradient.ppo.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

PPO is outlined in the paper Proximal Policy Optimization Algorithms. It is an improvement on TRPO; both are on-policy policy gradient RL algorithms.

fastrl originally included TRPO, however since TRPO is both more complicated and outperformed by PPO, we keep only PPO this time around.
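The heart of PPO is the clipped surrogate objective from the paper, which limits how far a single update can push the new policy away from the old one. Below is a minimal sketch of that objective in plain PyTorch, not fastrl's actual implementation (the function name is hypothetical):

```python
import torch

# A minimal sketch of PPO's clipped surrogate objective, negated so it can be
# minimized as a loss. Not fastrl's actual implementation.
def ppo_clip_loss(new_logp, old_logp, advantage, eps=0.2):
    # Probability ratio pi_theta(a|s) / pi_theta_old(a|s), computed in log-space
    ratio = torch.exp(new_logp - old_logp)
    # Clipping removes the incentive to push the ratio outside [1-eps, 1+eps]
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    # Pessimistic (elementwise min) bound on the surrogate objective
    return -torch.min(ratio * advantage, clipped * advantage).mean()

# When new and old policies agree, the ratio is 1 and the loss is -mean(advantage)
loss = ppo_clip_loss(torch.zeros(2), torch.zeros(2), torch.tensor([1., 2.]))
```

Clipping is what lets PPO take multiple gradient steps on the same batch of experience without the trust-region machinery TRPO needs.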

We will start with the Actor Critic. Note that the paper describes a way to share parameters, i.e. use a single model instead of two, however for simplicity we will use the separate Actor Critic architecture.

With this in mind, the Critic is a basic fully connected model for estimating values. The actor will be a probabilistic model.

{% raw %}

class Critic[source]

Critic(state_sz:int, action_sz:int=0, hidden=512) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

Parameters:

  • state_sz : <class 'int'>

  • action_sz : <class 'int'>, optional

  • hidden : <class 'int'>, optional

{% endraw %} {% raw %}
{% endraw %}
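In spirit, the Critic above is just a small fully connected network mapping a batch of states to one value estimate per state. A rough, hypothetical equivalent in plain PyTorch (the class name is made up for illustration):

```python
import torch
from torch import nn

# Hypothetical stand-in for fastrl's Critic: a plain fully connected value estimator.
class TinyCritic(nn.Module):
    def __init__(self, state_sz, hidden=512):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_sz, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),  # a single scalar value per state
        )

    def forward(self, x): return self.layers(x)

values = TinyCritic(state_sz=4)(torch.randn(10, 4))  # shape (10, 1)
```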

Not so complicated, right? The actor will be a little more involved, so we will walk through how it works.

We will start with some utilities that might be generally useful in other models.

{% raw %}

class ClampBlock[source]

ClampBlock(in_sz:int, out_sz:int, activation_cls=Softplus, clamp_min=0.3, clamp_max=10.0) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

Parameters:

  • in_sz : <class 'int'>

  • out_sz : <class 'int'>

  • activation_cls : <class 'type'>, optional

  • clamp_min : <class 'float'>, optional

  • clamp_max : <class 'float'>, optional

{% endraw %} {% raw %}
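Judging from its signature, ClampBlock plausibly chains a linear layer, an activation such as Softplus, and a clamp into `[clamp_min, clamp_max]`. A hedged sketch of that pattern, not fastrl's actual code:

```python
import torch
from torch import nn

# Rough stand-in for ClampBlock: linear -> Softplus -> clamp.
# Handy for standard-deviation heads, which must stay positive and bounded.
class TinyClampBlock(nn.Module):
    def __init__(self, in_sz, out_sz, clamp_min=0.3, clamp_max=10.0):
        super().__init__()
        self.lin, self.act = nn.Linear(in_sz, out_sz), nn.Softplus()
        self.clamp_min, self.clamp_max = clamp_min, clamp_max

    def forward(self, x):
        return self.act(self.lin(x)).clamp(self.clamp_min, self.clamp_max)

out = TinyClampBlock(8, 5)(torch.randn(10, 8))  # every entry in [0.3, 10.0]
```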
{% endraw %} {% raw %}

show_multisurface[source]

show_multisurface(surfaces, title='')

Parameters:

  • surfaces : <class 'inspect._empty'>

  • title : <class 'str'>, optional

{% endraw %} {% raw %}
{% endraw %} {% raw %}
input_dist=torch.randn(10,8)+5
distribution=nn.Linear(8,5)(input_dist)
show_multisurface([distribution,ClampBlock(8,5)(input_dist)],'Regular linear vs Clamped')
{% endraw %} {% raw %}

class ConstBlock[source]

ConstBlock(out_sz:int, constant=1.0, **kwargs) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

Parameters:

  • out_sz : <class 'int'>

  • constant : <class 'float'>, optional

  • kwargs : <class 'inspect._empty'>

{% endraw %} {% raw %}
{% endraw %}

ClampBlock and ConstBlock are used below for Gaussian Mixture Models. These can also be useful for non-probabilistic implementations.

{% raw %}
show_multisurface([distribution,ConstBlock(5,constant=20)(input_dist)],'Regular linear vs Constant')
{% endraw %}

As shown above, ConstBlock simply outputs a tensor filled with a single constant. This is useful for adding biases to a calculation.
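Assuming ConstBlock just fills a batch-sized tensor with its constant (a guess from its signature, not fastrl's actual code), a rough stand-in looks like this:

```python
import torch
from torch import nn

# Rough stand-in for ConstBlock: ignores the input's values and emits a fixed
# constant per output element, matched to the input's batch size.
class TinyConstBlock(nn.Module):
    def __init__(self, out_sz, constant=1.0):
        super().__init__()
        self.out_sz, self.constant = out_sz, constant

    def forward(self, x):
        return torch.full((x.shape[0], self.out_sz), self.constant)

out = TinyConstBlock(5, constant=20.0)(torch.randn(10, 8))  # all entries 20.0
```

A constant head like this is exactly what a fixed-variance Gaussian policy needs: the mean is learned while the standard deviation stays pinned.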

{% raw %}

clamp_or_const[source]

clamp_or_const(fix_variance:bool, activation_cls=Softplus, clamp_min=0.3, clamp_max=10.0)

Parameters:

  • fix_variance : <class 'bool'>

  • activation_cls : <class 'type'>, optional

  • clamp_min : <class 'float'>, optional

  • clamp_max : <class 'float'>, optional

{% endraw %} {% raw %}
{% endraw %} {% raw %}
test_eq(type(clamp_or_const(fix_variance=True,in_sz=3,out_sz=2)),ConstBlock)
test_eq(type(clamp_or_const(fix_variance=False,in_sz=3,out_sz=2)),ClampBlock)
{% endraw %} {% raw %}

show_surface[source]

show_surface(surface, title='')

Parameters:

  • surface : <class 'inspect._empty'>

  • title : <class 'str'>, optional

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class SimpleGMM[source]

SimpleGMM(in_sz, out_sz, fix_variance:bool=False) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

Parameters:

  • in_sz : <class 'inspect._empty'>

  • out_sz : <class 'inspect._empty'>

  • fix_variance : <class 'bool'>, optional

{% endraw %} {% raw %}
{% endraw %} {% raw %}
simple_gmm=SimpleGMM(5,30)
dist=simple_gmm(torch.full((10,5),1.0))
show_surface(dist.sample((1,))[0])
{% endraw %} {% raw %}
simple_gmm=SimpleGMM(5,30,fix_variance=True)
dist=simple_gmm(torch.full((10,5),1.0))
show_surface(dist.sample((1,))[0])
{% endraw %} {% raw %}

class MultiCompGMM[source]

MultiCompGMM(in_sz, out_sz, n_components, fix_variance:bool=False, activation_cls=Softplus, clamp_min=0.3, clamp_max=10.0) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

Parameters:

  • in_sz : <class 'inspect._empty'>

  • out_sz : <class 'inspect._empty'>

  • n_components : <class 'inspect._empty'>

  • fix_variance : <class 'bool'>, optional

  • activation_cls : <class 'type'>, optional

  • clamp_min : <class 'float'>, optional

  • clamp_max : <class 'float'>, optional

{% endraw %} {% raw %}
{% endraw %} {% raw %}
multi_component_gmm=MultiCompGMM(5,30,2)
dist=multi_component_gmm(torch.full((10,5),1.0))
show_surface(dist.sample((1,))[0])
{% endraw %} {% raw %}
multi_component_gmm=MultiCompGMM(5,30,4)
dist=multi_component_gmm(torch.full((10,5),1.0))
show_surface(dist.sample((1,))[0])
{% endraw %}
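To make the GMM heads above less mysterious: a single-component Gaussian head can be built from a linear mean head plus a bounded standard-deviation head (the role ClampBlock plays), returning a torch distribution to sample actions from. This is a hedged sketch of the general pattern, not SimpleGMM's actual code:

```python
import torch
from torch import nn
from torch.distributions import Normal

# Sketch of a single-component Gaussian head: a linear mean head plus a
# clamped, always-positive std head.
class TinyGaussianHead(nn.Module):
    def __init__(self, in_sz, out_sz, clamp_min=0.3, clamp_max=10.0):
        super().__init__()
        self.mean = nn.Linear(in_sz, out_sz)
        self.std = nn.Sequential(nn.Linear(in_sz, out_sz), nn.Softplus())
        self.clamp_min, self.clamp_max = clamp_min, clamp_max

    def forward(self, x):
        std = self.std(x).clamp(self.clamp_min, self.clamp_max)
        return Normal(self.mean(x), std)

dist = TinyGaussianHead(5, 30)(torch.full((10, 5), 1.0))
action = dist.sample()                # one action vector per batch element
logp = dist.log_prob(action).sum(-1)  # per-sample log-prob, used by the PPO ratio
```

Returning a distribution rather than a raw tensor is what lets the agent both sample actions and later evaluate their log-probabilities for the clipped objective.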

PPO Implementation

{% raw %}
from fastai.data.all import *
from fastrl.data.block import *
from fastrl.agent import *
{% endraw %} {% raw %}
GymSrc??
Init signature: GymSrc()
Docstring:      Basic class handling tweaks of a callback loop by changing a `obj` in various events
Source:        
class GymSrc(SrcCallback):
    def initialize(self):
        self.source.histories,self.source.pool=L((deque(maxlen=self.steps_count),
                          gym.make(self.env,**self.env_kwargs))
                          for _ in range(self.n_envs)).zip().map(L)
        if self.source.agent is None:
            test_cb=TstCallback(action_space=self.pool[0].action_space)
            self.source.agent=Agent(cbs=test_cb)
        self.action_is_int=False
       # Extra fields
        self.source.next_state=None
        self.source.state=None
        self.source.image=None
        self.source.action=None
        self.source.episode_num=None
        self.source.accum_rewards=None
        self.init_shapes()
        self('reset')
        if self.mode is not None: self.init_render_shapes()
        self.source.imask=torch.zeros((self.n_envs,)).bool()

    def init_render_shapes(self):
        "Set the image shapes with a batch dim."
        image_shape=_env_render(self.pool[0],self.mode).shape
        self.shape_map['image']=_add_batch(image_shape)

    def init_shapes(self):
        "Set the reward shape, state shapes."
        self.shape_map['reward']=(1,1)
        obs_shape=_add_batch(self.pool[0].observation_space.shape)
        for k in ('state','next_state'): self.shape_map[k]=obs_shape
        action=self.pool[0].action_space.sample()
        if isinstance(action,(int,float)):
            self.shape_map['action']=(1,1)
            if isinstance(action,int): self.action_is_int=True
        else:
            self.shape_map['action']=_add_batch(action)

    def reset(self):
        if self.imask.sum()==0:
            _shape_map={k:v for k,v in self.shape_map.items() if k!='episode_num'}
            reset_exps=D(_shape_map).mapv(_batchwise_zero,bs=self.n_envs)
            for k,v in reset_exps.items(): setattr(self.source,k,v)
            self.source.pool.map(_env_seed,seed=self.seed)
            self.source.state=self.pool.map(_env_reset,shape_map=self.shape_map)
            self.source.state=TensorBatch.vstack(tuple(self.state))
            self.source.done=self.source.done.bool()
            self.source.all_done=self.source.all_done.bool()
            self.source.env_id=self.source.env_id.long()
            self.source.p_id=self.source.p_id.long()
            if self.mode is not None:
                self.source.image=self.pool.map(_env_render,mode=self.mode)
            if self.action_is_int:
                self.source.action=self.source.action.long()

    def after_reset(self):
        if self.imask.sum()==0:
            if self.source.episode_num is None:
                self.source.episode_num=self.source.env_id.detach().clone()+1
            else:
                self.source.episode_num+=self.source.episode_num.max()+1-self.source.episode_num
            self.source.all_exp=deepcopy(self.active_exp(self.n_envs))
            self.source.imask=torch.ones((self.n_envs,)).bool()

    def do_action(self):
        self.source.action,exp=self.agent.do_action(**self.active_exp())
        # TODO: Im not sure we need this section, maybe continuous actions need this here though?
#         self.source.action=_fix_shape_map(self.action,self.shape_map['action'],bs=int(self.imask.sum()))
        if not isinstance(self.action,TensorBatch):
            self.source.action=TensorBatch(self.source.action)
        if self.action_is_int:
            self.source.action=self.source.action.long()
        for k in exp:
            if k=='action': continue
            setattr(self.source,k,exp[k])

    def do_step(self):
        step_res=self.pool[self.imask].zipwith(self.action).starmap(_env_step,
                                                                    shape_map=self.shape_map,
                                                                    is_int=self.action_is_int)
        next_states,rewards,dones=step_res.zip()[:3]
        if self.mode is not None:
            self.source.image=self.pool[self.imask].map(_env_render,mode=self.mode)
        self.source.next_state=TensorBatch.vstack(next_states)
        self.source.reward=TensorBatch.vstack(rewards)
        self.source.accum_rewards+=TensorBatch.vstack(rewards)
        self.source.done=TensorBatch.vstack(dones).bool()
        running_mask=self.imask.nonzero().reshape(-1,1)
        self.source.env_id=TensorBatch(running_mask,bs=self.imask.sum()).long()
        worker_id=get_worker_info()
        worker_id=worker_id.id if worker_id is not None else 0
        self.source.p_id=TensorBatch(torch.full(running_mask.shape,worker_id),bs=self.imask.sum()).long()

    def history(self):
        if self.current_history['all_done'].shape[0]==1 and self.current_history['done'][0][0]:
#             print(self.current_history)
            self.current_history['all_done'][0][0]=True

    def after_history(self):
        active_exp=self.active_exp()
        for k in self.all_exp:
            if k in ['episode_num']:continue
            self.all_exp[k][self.imask]=torch.clone(active_exp[k])
        self.source.imask=~self.all_exp['done'].reshape(-1,)
        for k in self.all_exp:
            if k in ['episode_num']:continue
            setattr(self.source,k,self.all_exp[k][self.imask])
        self.source.state=torch.clone(self.next_state)
        self.source.step+=1
File:           ~/fastrl/fastrl/data/block.py
Type:           type
Subclasses:     
{% endraw %} {% raw %}
model=MultiCompGMM(4,2,1)
agent=Agent(model)
source=Src('Pendulum-v0',agent,seed=None,steps_count=1,n_envs=1,steps_delta=1,cbs=[GymSrc,FirstLast])
{% endraw %} {% raw %}
for x in DataLoader(source,bs=1,n=20):print(x)

TypeError                                 Traceback (most recent call last)
<ipython-input-130-b72a54751bac> in <module>
----> 1 for x in DataLoader(source,bs=1,n=20):print(x)

[... fastai DataLoader and fastrl callback frames elided ...]

~/fastrl/fastrl/data/block.py in reset(self)
    222         if self.imask.sum()==0:
    223             _shape_map={k:v for k,v in self.shape_map.items() if k!='episode_num'}
--> 224             reset_exps=D(_shape_map).mapv(_batchwise_zero,bs=self.n_envs)
    225             for k,v in reset_exps.items(): setattr(self.source,k,v)
    226             self.source.pool.map(_env_seed,seed=self.seed)

~/fastrl/fastrl/data/block.py in _batchwise_zero(shape, bs)
    179     return (s,r,d,info)
    180 
--> 181 def _batchwise_zero(shape,bs): return torch.zeros((bs,*shape[1:]))
    182 
    183 class GymSrc(SrcCallback):

TypeError: zeros(): argument 'size' must be tuple of ints, but found element of type numpy.float32 at pos 2
{% endraw %}